{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# textgenrnn 1.1 Training with Context\n",
    "\n",
    "by [Max Woolf](http://minimaxir.com)\n",
    "\n",
    "*Max's open-source projects are supported by his [Patreon](https://www.patreon.com/minimaxir). If you found this project helpful, any monetary contributions to the Patreon are appreciated and will be put to good creative use.*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Intro\n",
    "\n",
    "textgenrnn is a Python module built on top of Keras/TensorFlow that can easily generate text using a pretrained recurrent neural network:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from textgenrnn import textgenrnn\n",
    "\n",
    "textgen = textgenrnn()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train on New Text\n",
    "\n",
    "Here's textgenrnn trained on a dataset of the top 1,000 submissions to each of /r/politics and /r/rarepuppers during 2017 (2,000 texts in total): two subreddits with drastically different styles! With `gen_epochs=5`, sample texts are printed after every fifth epoch. How well can a new model learn these styles?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2,000 texts collected.\n",
      "Training new model w/ 2-layer, 128-cell LSTMs\n",
      "Training on 114,560 character sequences.\n",
      "Epoch 1/10\n",
      "895/895 [==============================] - 130s 145ms/step - loss: 2.8652\n",
      "Epoch 2/10\n",
      "895/895 [==============================] - 124s 139ms/step - loss: 2.2462\n",
      "Epoch 3/10\n",
      "895/895 [==============================] - 126s 141ms/step - loss: 1.9843\n",
      "Epoch 4/10\n",
      "895/895 [==============================] - 130s 145ms/step - loss: 1.8106\n",
      "Epoch 5/10\n",
      "895/895 [==============================] - 131s 147ms/step - loss: 1.6756\n",
      "####################\n",
      "Temperature: 0.2\n",
      "####################\n",
      "Trump to the sand to the senator of the senity controver to the protection to says his the senator to to the senator to the sand in the sand to the senator to the senator to thing to the senator in the senator to the senior and to says the senator to the senator to republican says Trump and to the\n",
      "\n",
      "Trump to the sand to to to be to boye boye still boye boye still frens\n",
      "\n",
      "Trump to the sand and to say to the with the son the senator to the senator of the protection to says to contract of Trump to thing man the states\n",
      "\n",
      "####################\n",
      "Temperature: 0.5\n",
      "####################\n",
      "Trump takes to for the US contaction office to the US states history contrapters\n",
      "\n",
      "Shew to the make to boy to recovery frens\n",
      "\n",
      "The doing a heckin good corgo does a heckin good brock on har protect in the U.S.. me to proming a dong and out on historce investigation of Trump to the US got his is contronce\n",
      "\n",
      "####################\n",
      "Temperature: 1.0\n",
      "####################\n",
      "Trump tappe, Donald Trump this soor is blover privves do frean\n",
      "\n",
      "Whan good mage could' voteduce\n",
      "\n",
      "Trump invignians his daye Jarses firm underge ties\n",
      "\n",
      "Epoch 6/10\n",
      "895/895 [==============================] - 124s 139ms/step - loss: 1.5677\n",
      "Epoch 7/10\n",
      "895/895 [==============================] - 124s 139ms/step - loss: 1.4781\n",
      "Epoch 8/10\n",
      "895/895 [==============================] - 128s 143ms/step - loss: 1.4687\n",
      "Epoch 9/10\n",
      "895/895 [==============================] - 129s 144ms/step - loss: 1.3409\n",
      "Epoch 10/10\n",
      "895/895 [==============================] - 127s 142ms/step - loss: 1.2742\n",
      "####################\n",
      "Temperature: 0.2\n",
      "####################\n",
      "Trump the plans to send a president to the president\n",
      "\n",
      "Trump administration for the U.S. of the president in the million is states her don't me a still the still state he dad a heckin big hold streatched to serve protect in the send in the was the down he dad a heckin big hold recovery\n",
      "\n",
      "Trump that he dad a street doing a heckin bamboozle and the sand and the still could resign\n",
      "\n",
      "####################\n",
      "Temperature: 0.5\n",
      "####################\n",
      "The president didn't have to gest be evidence to compared the been president in the health call bill after Trump and dally demanded to their new president was may the was a big how said to be vv working the world riphters himelating adayion and Trump\n",
      "\n",
      "Trump to all will net neutrality comments\n",
      "\n",
      "Trump and President Investigation Protest Trump Kussia Masorbal\n",
      "\n",
      "####################\n",
      "Temperature: 1.0\n",
      "####################\n",
      "as after pupper\n",
      "\n",
      "lott smol pupper does a heckin good boye because news be fake news\n",
      "\n",
      "T O N G I E rev boye bamboozle?\n",
      "\n"
     ]
    }
   ],
   "source": [
    "file_path = \"../datasets/reddit_rarepuppers_politics_2000.txt\"\n",
    "\n",
    "textgen.reset()\n",
    "textgen.train_from_file(file_path, new_model=True, num_epochs=10, gen_epochs=5)"
   ]
  },
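  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And here's the same dataset, trained with context labels indicating the source subreddit. If using `train_from_file`, pass `context=True`; the input file must be a 2-column CSV, with the first column containing the texts and the second containing the labels.\n",
    "\n",
    "The `output_loss` is the loss along the text-only path. In this example, the context-label model ends with a slightly lower loss than the text-only training (an `output_loss` of 1.2384 vs. a `loss` of 1.2742 after epoch 10). The more disparate the texts, the more helpful providing context will be.\n",
    "\n",
    "Note that this run ends with a `ValueError` after the final epoch: the traceback points to the `if context_labels:` check in `train_on_texts`, which appears to be evaluating an array rather than a list. Training and sample generation had already completed by that point."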
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2,000 texts collected.\n",
      "Training new model w/ 2-layer, 128-cell LSTMs\n",
      "Training on 114,564 character sequences.\n",
      "Epoch 1/10\n",
      "895/895 [==============================] - 135s 151ms/step - loss: 2.8489 - context_output_loss: 2.8466 - output_loss: 2.8584\n",
      "Epoch 2/10\n",
      "895/895 [==============================] - 138s 154ms/step - loss: 2.2077 - context_output_loss: 2.2050 - output_loss: 2.2186\n",
      "Epoch 3/10\n",
      "895/895 [==============================] - 145s 162ms/step - loss: 1.9508 - context_output_loss: 1.9485 - output_loss: 1.9598\n",
      "Epoch 4/10\n",
      "895/895 [==============================] - 131s 146ms/step - loss: 1.7791 - context_output_loss: 1.7770 - output_loss: 1.7872\n",
      "Epoch 5/10\n",
      "895/895 [==============================] - 128s 143ms/step - loss: 1.6499 - context_output_loss: 1.6480 - output_loss: 1.6575\n",
      "####################\n",
      "Temperature: 0.2\n",
      "####################\n",
      "Trump the Congress to Trump Trump Americans Comment Committee Contact At Americans Americans Communications to Trump James Americans Comment Post Defore He Was White House Healthcare Senate Comey Puppers to Republican Comey Comey Consition Committee And Americans Committee Demand Trump Administrat\n",
      "\n",
      "Trump tas the Russia to the the the the me the states to the frens to the was the was the went with the does a heckin good boye\n",
      "\n",
      "Trump talks to senate committee in the was the does a heckin good boye the got for the the was the senate charge the was the the the the the the show does a heckin good boye\n",
      "\n",
      "####################\n",
      "Temperature: 0.5\n",
      "####################\n",
      "Donald Trump to say Trump in White House Republicans Trump Administration\n",
      "\n",
      "Trump came net neetings smol the Pudiate for US doggo\n",
      "\n",
      "Trump Jameed to Campaign As No News Fores Comey He Formand Ever Administration\n",
      "\n",
      "####################\n",
      "Temperature: 1.0\n",
      "####################\n",
      "Obana ad’ fight. 'kill states tist the Trump take gusticing what a Campher on Trump, 'thrut save a want' you the hol\n",
      "\n",
      "Neway: Firmonwailly are Billion got\n",
      "\n",
      "Vatters tike is every senting councer Trump-tavet Trump thoush will replicument in gusinsuss\n",
      "\n",
      "Epoch 6/10\n",
      "895/895 [==============================] - 132s 147ms/step - loss: 1.5449 - context_output_loss: 1.5432 - output_loss: 1.5521\n",
      "Epoch 7/10\n",
      "895/895 [==============================] - 126s 141ms/step - loss: 1.4532 - context_output_loss: 1.4515 - output_loss: 1.4601\n",
      "Epoch 8/10\n",
      "895/895 [==============================] - 133s 149ms/step - loss: 1.3715 - context_output_loss: 1.3698 - output_loss: 1.3783\n",
      "Epoch 9/10\n",
      "895/895 [==============================] - 126s 141ms/step - loss: 1.2969 - context_output_loss: 1.2952 - output_loss: 1.3037\n",
      "Epoch 10/10\n",
      "895/895 [==============================] - 126s 141ms/step - loss: 1.2316 - context_output_loss: 1.2300 - output_loss: 1.2384\n",
      "####################\n",
      "Temperature: 0.2\n",
      "####################\n",
      "Trump to send Trump asked to testify for the White House comments and the world to president is the wins the could be a heckin good boye\n",
      "\n",
      "Trump to suggested and the White House delieve the Russia sanctions with Russian and the was president\n",
      "\n",
      "The FCC Calling Congress to Trump Tower Politic President Trump Jeff Sessions Comey Comey Comey Comey Comey to Trump Comey Comey Comey Comey to Congress to Trump Russia investigation\n",
      "\n",
      "####################\n",
      "Temperature: 0.5\n",
      "####################\n",
      "Trump Campaign Accort Donald Trump Controbe To Sussing Is alleged Trump Russia investigation into a heckin stracking governing to take to do a president with Russia has stairs for the happy bamboozle\n",
      "\n",
      "Donald Trump recovers because of Trump to that pool on the White House beans solfise in the would low the servery's statement to beel to serve of the loves frens worker to cover the was to pup the down happy is not to go to hold time of the pass can at suclight and to don't have a place\n",
      "\n",
      "Trump to that runsident to be your a big but need alting to out work\n",
      "\n",
      "####################\n",
      "Temperature: 1.0\n",
      "####################\n",
      "poor is heckins from the minf?\n",
      "\n",
      "Uspeal, Vvvv fasted Trump Exazon offore helpuries\n",
      "\n",
      "Fren Refuse boyn the Thind Broust for Trump West to White House effered from plan in get nueting quitured\n",
      "\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-2-7e78179a1381>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mtextgen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mtextgen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_from_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnew_model\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgen_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/usr/local/lib/python3.6/site-packages/textgenrnn-1.1-py3.6.egg/textgenrnn/textgenrnn.py\u001b[0m in \u001b[0;36mtrain_from_file\u001b[0;34m(self, file_path, header, delim, new_model, context, **kwargs)\u001b[0m\n\u001b[1;32m    245\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mnew_model\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    246\u001b[0m             self.train_new_model(\n\u001b[0;32m--> 247\u001b[0;31m                 texts, context_labels=context_labels, **kwargs)\n\u001b[0m\u001b[1;32m    248\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    249\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_on_texts\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtexts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext_labels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcontext_labels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/site-packages/textgenrnn-1.1-py3.6.egg/textgenrnn/textgenrnn.py\u001b[0m in \u001b[0;36mtrain_new_model\u001b[0;34m(self, texts, context_labels, num_epochs, gen_epochs, batch_size, **kwargs)\u001b[0m\n\u001b[1;32m    220\u001b[0m                             \u001b[0mgen_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgen_epochs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    221\u001b[0m                             \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 222\u001b[0;31m                             **kwargs)\n\u001b[0m\u001b[1;32m    223\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    224\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweights_path\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"textgenrnn_weights_saved.hdf5\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/site-packages/textgenrnn-1.1-py3.6.egg/textgenrnn/textgenrnn.py\u001b[0m in \u001b[0;36mtrain_on_texts\u001b[0;34m(self, texts, context_labels, batch_size, num_epochs, verbose, new_model, gen_epochs, prop_keep, max_gen_length, **kwargs)\u001b[0m\n\u001b[1;32m    162\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    163\u001b[0m         \u001b[0;31m# Keep the text-only version of the model if using context labels\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 164\u001b[0;31m         \u001b[0;32mif\u001b[0m \u001b[0mcontext_labels\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    165\u001b[0m             self.model = Model(inputs=self.model.input[0],\n\u001b[1;32m    166\u001b[0m                                outputs=self.model.output[1])\n",
      "\u001b[0;31mValueError\u001b[0m: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()"
     ]
    }
   ],
   "source": [
    "file_path = \"../datasets/reddit_rarepuppers_politics_2000_context.csv\"\n",
    "\n",
    "textgen.reset()\n",
    "textgen.train_from_file(file_path, new_model=True, num_epochs=10, gen_epochs=5, context=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also pass context labels to `train_on_texts` or `train_new_model` (used below) by providing a list of context labels with the same length as the list of texts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each text needs a matching context label at the same index\n",
    "texts = ['Never gonna give you up, never gonna let you down',\n",
    "         'Never gonna run around and desert you',\n",
    "         'Never gonna make you cry, never gonna say goodbye',\n",
    "         'Never gonna tell a lie and hurt you']\n",
    "\n",
    "context_labels = ['Verse 1', 'Verse 1', 'Verse 2', 'Verse 2']\n",
    "\n",
    "# Reset the model, then train a new one on the labeled texts\n",
    "textgen.reset()\n",
    "textgen.train_new_model(texts, context_labels=context_labels, max_length=5, gen_epochs=5, num_epochs=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LICENSE\n",
    "\n",
    "MIT License\n",
    "\n",
    "Copyright (c) 2018 Max Woolf\n",
    "\n",
    "Permission is hereby granted, free of charge, to any person obtaining a copy\n",
    "of this software and associated documentation files (the \"Software\"), to deal\n",
    "in the Software without restriction, including without limitation the rights\n",
    "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
    "copies of the Software, and to permit persons to whom the Software is\n",
    "furnished to do so, subject to the following conditions:\n",
    "\n",
    "The above copyright notice and this permission notice shall be included in all\n",
    "copies or substantial portions of the Software.\n",
    "\n",
    "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
    "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
    "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
    "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
    "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
    "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n",
    "SOFTWARE."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
